#!/usr/bin/env python
# coding: utf8

"""
    Module for building a data preprocessing pipeline with the TensorFlow
    data API. Preprocessing steps such as audio loading, spectrogram
    computation, cropping, feature caching and data augmentation are
    performed on a TensorFlow dataset object that outputs a tuple
    (input_, output) where:

    -   input_ is a dictionary with a single key that contains the (batched)
        mix spectrogram of audio samples
    -   output is a dictionary of spectrograms of the isolated tracks
        (ground truth)
"""

import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply,
)

# pylint: enable=import-error

__author__ = "Deezer Research"
__email__ = "spleeter@deezer.com"
__license__ = "MIT License"

# Default audio parameters: a two-stem ("vocals" / "accompaniment") split
# of the "mix" track sampled at 44100 Hz, an STFT with 4096-sample frames
# and a 1024-sample step, and segments of T=512 frames by F=1024 bins.
# NOTE: the "n_channels" key, which DatasetBuilder reads, is not defined
# here; callers must add it before building a dataset.
DEFAULT_AUDIO_PARAMS: Dict = dict(
    instrument_list=("vocals", "accompaniment"),
    mix_name="mix",
    sample_rate=44100,
    frame_length=4096,
    frame_step=1024,
    T=512,
    F=1024,
)


def get_training_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Build the dataset used for training.

    The chunk duration (default 20.0) and random seed (default 0) are
    read from ``audio_params``, as are the CSV path (``train_csv``), the
    cache location (``training_cache``), the batch size and the number
    of chunks per song (default 2). Random data augmentation is
    disabled, spectrograms are converted to uint around the cache, and
    the builder does not wait for an existing cache.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    params = audio_params
    builder = DatasetBuilder(
        params,
        audio_adapter,
        audio_path,
        chunk_duration=params.get("chunk_duration", 20.0),
        random_seed=params.get("random_seed", 0),
    )
    build_options = dict(
        cache_directory=params.get("training_cache"),
        batch_size=params.get("batch_size"),
        n_chunks_per_song=params.get("n_chunks_per_song", 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False,
    )
    return builder.build(params.get("train_csv"), **build_options)


def get_validation_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Build the dataset used for validation.

    Uses fixed 12-second chunks and one central chunk per song. The CSV
    path (``validation_csv``), cache location (``validation_cache``) and
    batch size come from ``audio_params``. The pipeline is deterministic:
    no shuffling, no random time crop, no data augmentation, and a
    finite (non-repeating) dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=12.0,
    )
    build_options = dict(
        batch_size=audio_params.get("batch_size"),
        cache_directory=audio_params.get("validation_cache"),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # Evaluation must not use random data augmentation.
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False,
    )
    return builder.build(audio_params.get("validation_csv"), **build_options)


class InstrumentDatasetBuilder(object):
    """
    Per-instrument provider of the map / filter callables used when
    building the dataset.

    Every callable reads and writes sample keys derived from the
    instrument name (``<instrument>_path``, ``<instrument>_spectrogram``,
    ``min_<instrument>_spectrogram`` and ``max_<instrument>_spectrogram``)
    and reads audio / STFT / cropping settings from the parent builder.
    """

    def __init__(self, parent, instrument) -> None:
        """
        Default constructor.

        Parameters:
            parent:
                Parent dataset builder. Its ``_audio_adapter``,
                ``_chunk_duration``, ``_sample_rate``, ``_frame_length``,
                ``_frame_step``, ``_F``, ``_T`` and ``_n_channels``
                attributes are read by the methods of this class.
            instrument:
                Target instrument name, used to derive the sample keys.
        """
        self._parent = parent
        self._instrument = instrument
        spectrogram_key = f"{instrument}_spectrogram"
        self._spectrogram_key = spectrogram_key
        self._min_spectrogram_key = f"min_{spectrogram_key}"
        self._max_spectrogram_key = f"max_{spectrogram_key}"

    def load_waveform(self, sample):
        """
        Load this instrument's audio chunk, starting at ``sample["start"]``
        and lasting the parent's chunk duration. Every entry returned by the
        audio adapter (its waveform is named ``waveform``) is merged into a
        copy of the sample.
        """
        parent = self._parent
        loaded = parent._audio_adapter.load_tf_waveform(
            sample[f"{self._instrument}_path"],
            offset=sample["start"],
            duration=parent._chunk_duration,
            sample_rate=parent._sample_rate,
            waveform_name="waveform",
        )
        return {**sample, **loaded}

    def compute_spectrogram(self, sample):
        """
        Compute the magnitude spectrogram (exponents 1.0) of
        ``sample["waveform"]`` and store it under this instrument's
        spectrogram key.
        """
        parent = self._parent
        spectrogram = compute_spectrogram_tf(
            sample["waveform"],
            frame_length=parent._frame_length,
            frame_step=parent._frame_step,
            spec_exponent=1.0,
            window_exponent=1.0,
        )
        return {**sample, self._spectrogram_key: spectrogram}

    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins (axis 1) of the spectrogram. """
        key = self._spectrogram_key
        truncated = sample[key][:, : self._parent._F, :]
        return {**sample, key: truncated}

    def convert_to_uint(self, sample):
        """
        Convert the spectrogram from float to uint with
        ``spectrogram_to_db_uint``. The entries it returns (keyed by the
        spectrogram, min and max keys) are merged into the sample.
        """
        quantized = spectrogram_to_db_uint(
            sample[self._spectrogram_key],
            tensor_key=self._spectrogram_key,
            min_key=self._min_spectrogram_key,
            max_key=self._max_spectrogram_key,
        )
        return {**sample, **quantized}

    def filter_infinity(self, sample):
        """ Predicate: keep samples whose stored min value is not infinite. """
        lower_bound = sample[self._min_spectrogram_key]
        return tf.logical_not(tf.math.is_inf(lower_bound))

    def convert_to_float32(self, sample):
        """
        Convert the uint spectrogram back to float gain values, using the
        stored min / max values.
        """
        key = self._spectrogram_key
        restored = db_uint_spectrogram_to_gain(
            sample[key],
            sample[self._min_spectrogram_key],
            sample[self._max_spectrogram_key],
        )
        return {**sample, key: restored}

    def time_crop(self, sample):
        """
        Crop T frames (axis 0) centred on the middle of the spectrogram.
        The start index is clamped at 0.
        """
        key = self._spectrogram_key
        frames = self._parent._T
        spectrogram = sample[key]
        # Middle of the time axis minus half the crop length, as int32.
        begin = tf.cast(
            tf.maximum(tf.shape(spectrogram)[0] / 2 - frames / 2, 0),
            tf.int32,
        )
        return {**sample, key: spectrogram[begin : begin + frames, :, :]}

    def filter_shape(self, sample):
        """ Predicate: keep samples whose spectrogram has shape (T, F, n_channels). """
        parent = self._parent
        expected = (parent._T, parent._F, parent._n_channels)
        return check_tensor_shape(sample[self._spectrogram_key], expected)

    def reshape_spectrogram(self, sample):
        """ Set the static shape (T, F, n_channels) on the spectrogram. """
        parent = self._parent
        key = self._spectrogram_key
        expected = (parent._T, parent._F, parent._n_channels)
        return {**sample, key: set_tensor_shape(sample[key], expected)}


class DatasetBuilder(object):
    """
    Builds the tf.data preprocessing pipeline for source separation.

    ``build`` chains the steps:
    - compute chunk start offsets per song,
    - load each instrument's waveform (mix included),
    - compute and frequency-truncate its spectrogram,
    - optionally quantize to uint around an optional cache,
    - crop T frames in time,
    - optionally apply data augmentation,
    - emit batched ``(input_, output)`` dictionary pairs.
    """

    # Margin at beginning and end of songs in seconds.
    MARGIN: float = 0.5

    # Wait period for cache (in seconds).
    WAIT_PERIOD: int = 60

    def __init__(
        self,
        audio_params: Dict,
        audio_adapter: "AudioAdapter",
        audio_path: str,
        random_seed: int = 0,
        chunk_duration: float = 20.0,
    ) -> None:
        """
        Default constructor.

        Parameters:
            audio_params (Dict):
                Audio parameters to use. The keys ``T``, ``F``,
                ``sample_rate``, ``frame_length``, ``frame_step``,
                ``mix_name``, ``n_channels`` and ``instrument_list`` are
                required. ``instrument_list`` may be any sequence (list or
                tuple).
            audio_adapter (AudioAdapter):
                Audio adapter used to load waveforms.
            audio_path (str):
                Root path joined in front of every ``<instrument>_path``
                value of each sample.
            random_seed (int):
                Seed used for shuffling and random time cropping.
            chunk_duration (float):
                Duration, in seconds, of the audio chunk loaded per sample.

        Raises:
            ValueError:
                If ``F`` or ``T`` are incompatible with the STFT parameters
                and the chunk duration.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params["T"]
        # Number of frequency bins to be used (must be
        # at most frame_length/2 + 1)
        self._F = audio_params["F"]
        self._sample_rate = audio_params["sample_rate"]
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
        self._n_channels = audio_params["n_channels"]
        # list() so a tuple instrument_list (as in DEFAULT_AUDIO_PARAMS)
        # can be concatenated with the mix name instead of raising TypeError.
        self._instruments = [self._mix_name] + list(
            audio_params["instrument_list"]
        )
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed

        self.check_parameters_compatibility()

    def check_parameters_compatibility(self):
        """
        Validate F and T against the STFT parameters and chunk duration.

        Raises:
            ValueError:
                If F exceeds frame_length/2 + 1, or if a chunk yields fewer
                than T spectrogram frames.
        """
        if self._frame_length / 2 + 1 < self._F:
            raise ValueError(
                "F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
            )

        # Approximate number of STFT frames in one chunk.
        n_frames = (
            self._chunk_duration * self._sample_rate - self._frame_length
        ) / self._frame_step
        if n_frames < self._T:
            raise ValueError(
                "T is too large considering STFT parameters and chunk duration. Make sure spectrogram time dimension of chunks is larger than T (for instance reducing T or frame_step or increasing chunk duration)."
            )

    def expand_path(self, sample):
        """ Prefix every ``<instrument>_path`` of the sample with the audio root path. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )

    def filter_error(self, sample):
        """ Predicate: drop samples whose waveform loading reported an error. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Return a copy of the sample without its ``waveform`` entry. """
        return {k: v for k, v in sample.items() if k != "waveform"}

    def harmonize_spectrogram(self, sample):
        """
        Truncate every instrument spectrogram (mix included) to the shortest
        time length (axis 0) among them.
        """
        min_length = tf.reduce_min(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0]
                for instrument in self._instruments
            ]
        )
        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    :min_length, :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """
        Predicate: keep samples whose instrument spectrograms all have at
        least T frames.
        """
        # reduce_all matches the intent "every spectrogram is long enough";
        # after harmonize_spectrogram all lengths are equal anyway.
        return tf.reduce_all(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )

    def random_time_crop(self, sample):
        """
        Randomly crop T frames from the instrument spectrograms.

        The spectrograms are processed jointly through ``sync_apply``. The
        crop size spans ``len(instruments) * F`` bins on axis 1, so
        presumably they are concatenated along that axis and share the same
        crop offset.
        """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, self._n_channels),
                    seed=self._random_seed,
                ),
            ),
        )

    def random_time_stretch(self, sample):
        """
        Randomly time-stretch (factor in [0.9, 1.1]) the instrument
        spectrograms jointly through ``sync_apply``.
        """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """
        Randomly pitch-shift (shift in [-1.0, 1.0]) the instrument
        spectrograms jointly through ``sync_apply`` (concatenated on axis 0).
        """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )

    def map_features(self, sample):
        """
        Split the sample into ``(input_, output)``. ``input_`` holds the mix
        spectrogram; ``output`` holds one spectrogram per target instrument.
        """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)

    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Each song is mapped ``n_chunks_per_song`` times; each copy gets a
        ``start`` offset (seconds, clamped at 0). A single chunk is centred
        in the song. Several chunks are evenly spaced between ``MARGIN`` and
        ``duration - chunk_duration - MARGIN``.

        Parameters:
            dataset (Any):
                Dataset to compute segments for; samples must provide a
                ``duration`` entry.
            n_chunks_per_song (int):
                Number of segment per song to compute.

        Returns:
            Any:
                Segmented dataset.

        Raises:
            ValueError:
                If ``n_chunks_per_song`` is not positive.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")

        def chunk_start_mapper(k):
            """ Return a map function that adds the k-th chunk ``start`` offset. """
            if n_chunks_per_song == 1:  # Take central segment.
                return lambda sample: dict(
                    sample,
                    start=tf.maximum(
                        sample["duration"] / 2 - self._chunk_duration / 2, 0
                    ),
                )
            return lambda sample: dict(
                sample,
                start=tf.maximum(
                    k
                    * (sample["duration"] - self._chunk_duration - 2 * self.MARGIN)
                    / (n_chunks_per_song - 1)
                    + self.MARGIN,
                    0,
                ),
            )

        # k is bound per factory call, so each mapped dataset uses its own
        # chunk index instead of a closure over the loop variable.
        datasets = [
            dataset.map(chunk_start_mapper(k)) for k in range(n_chunks_per_song)
        ]
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        One InstrumentDatasetBuilder per instrument (mix first) is created on
        first iteration and reused on later iterations.

        Yields:
            Any:
                InstrumentDatasetBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = [
                InstrumentDatasetBuilder(self, instrument)
                for instrument in self._instruments
            ]
        yield from self._instrument_builders

    def cache(self, dataset: Any, cache: Optional[str], wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Optionally waits for
        ``<cache>.index`` to exist first (useful if another process is
        already computing the cache) when the wait flag is `True`.

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (Optional[str]):
                Cache file path prefix passed to ``dataset.cache``; None
                disables caching. Its parent directory is created if needed.
            wait (bool):
                If caching is enabled, True if the cache should be waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is None:
            return dataset
        if wait:
            while not exists(f"{cache}.index"):
                logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                time.sleep(self.WAIT_PERIOD)
        cache_path = os.path.dirname(cache)
        # A bare prefix such as "cache" has no directory component, and
        # os.makedirs("") would raise FileNotFoundError.
        if cache_path:
            os.makedirs(cache_path, exist_ok=True)
        return dataset.cache(cache)

    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: int = 2,
    ) -> Any:
        """
        Build the full preprocessing pipeline.

        Parameters:
            csv_path (str):
                CSV file read with ``dataset_from_csv``. Its samples must
                provide ``duration`` and one ``<instrument>_path`` entry per
                instrument (mix included).
            batch_size (int):
                Number of samples per batch.
            shuffle (bool):
                Shuffle samples before caching and again after cropping.
            convert_to_uint (bool):
                Quantize spectrograms to uint before caching and convert
                them back to float32 afterwards.
            random_data_augmentation (bool):
                Apply random time stretch and pitch shift.
            random_time_crop (bool):
                Randomly crop T frames; otherwise crop the central T frames.
            infinite_generator (bool):
                Repeat the dataset indefinitely.
            cache_directory (Optional[str]):
                Cache path prefix (see ``cache``); None disables caching.
            wait_for_cache (bool):
                Block until the cache index file exists.
            num_parallel_calls (int):
                Parallelism for loading, spectrogram, random crop and float
                conversion maps.
            n_chunks_per_song (int):
                Number of chunks extracted per song.

        Returns:
            Any:
                Batched dataset of ``(input_, output)`` dictionary pairs.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached :
                reshuffle_each_iteration=True,
            )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
        # K bins frequencies, and waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen)
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching ?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too short segment.
        # NOTE: could be done before caching ?
        dataset = dataset.filter(self.filter_short_segments)
        # Crop T frames (randomly, or the central segment for validation).
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel call post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # error due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset
